Skip to content

feat(gdn): add unified decode API and deprecation shims (RFC 5.7, 5.8)#2706

Open
Dayuxiaoshui wants to merge 4 commits intoflashinfer-ai:mainfrom
Dayuxiaoshui:main
Open

feat(gdn): add unified decode API and deprecation shims (RFC 5.7, 5.8)#2706
Dayuxiaoshui wants to merge 4 commits intoflashinfer-ai:mainfrom
Dayuxiaoshui:main

Conversation

@Dayuxiaoshui
Copy link
Copy Markdown

@Dayuxiaoshui Dayuxiaoshui commented Mar 6, 2026

📌 Description

Implements RFC 5.7/5.8: add unified GDN decode API gated_delta_rule_decode_unified (state_layout, state_indices, intermediate_states_buffer, disable_state_update), rename backends to _*_impl with dispatch by layout/dtype/T, add deprecation shims for the three legacy decode APIs, export in init.py, and add tests/gdn/test_gdn_decode_unified.py (cross-check, intermediate_buffer, edge cases, error paths).

🔍 Related Issues

Closes #2687

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Unified decode API that selects the appropriate backend at runtime.
  • Deprecations

    • Backend-specific decode variants deprecated; legacy shims emit warnings.
  • Improvements

    • Stronger validation with clearer error messages for invalid state/layout/parameter combos.
    • Expanded test coverage and updated benchmarks to verify parity across implementations.

- Add gated_delta_rule_decode_unified as single entry point with state_layout (VK/KV), state_indices, intermediate_states_buffer, disable_state_update.

- Rename existing implementations to _*_impl; unified dispatches by state_layout, state dtype, and T.

- Add deprecation shims for gated_delta_rule_decode_pretranspose, gated_delta_rule_decode (KV), gated_delta_rule_mtp with DeprecationWarning.

- Export gated_delta_rule_decode_unified in __init__.py.

- Add tests/gdn/test_gdn_decode_unified.py: cross-check vs _*_impl, intermediate_states_buffer, edge cases (pool_size=1, B=1), error paths.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request streamlines the Gated Delta Rule (GDN) decode functionality by introducing a unified API. This change centralizes the dispatch logic for different GDN decode backends, improving maintainability and future extensibility. Concurrently, it gracefully transitions users away from older, specialized decode functions by providing deprecation shims, ensuring a smooth migration path while consolidating the API surface.

Highlights

  • Unified GDN Decode API: Introduced a new unified GDN decode API, gated_delta_rule_decode_unified, to serve as a single entry point for various decode operations (T=1 and MTP).
  • Backend Renaming: Renamed existing GDN decode backend functions (e.g., gated_delta_rule_decode_pretranspose) to internal _*_impl versions to support the new unified dispatch mechanism.
  • Deprecation Shims: Implemented deprecation shims for the three legacy GDN decode APIs, guiding users to the new unified API and issuing DeprecationWarnings.
  • New Test Coverage: Added a new test file, test_gdn_decode_unified.py, to thoroughly validate the functionality of the unified API, including cross-checking with legacy behavior, intermediate buffer handling, edge cases, and error paths.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • flashinfer/init.py
    • Exported the new gated_delta_rule_decode_unified function.
  • flashinfer/gdn_decode.py
    • Imported the warnings module for deprecation notices.
    • Renamed gated_delta_rule_decode_pretranspose to _gated_delta_rule_decode_pretranspose_impl.
    • Renamed gated_delta_rule_decode to _gated_delta_rule_decode_kv_impl.
    • Renamed gated_delta_rule_mtp to _gated_delta_rule_mtp_impl.
    • Added the gated_delta_rule_decode_unified function, which dispatches to the appropriate _impl backend based on parameters.
    • Re-added gated_delta_rule_decode_pretranspose, gated_delta_rule_decode, and gated_delta_rule_mtp as deprecation shims that call the unified API and issue warnings.
  • tests/gdn/test_gdn_decode_unified.py
    • Added a new test file to verify the gated_delta_rule_decode_unified API.
    • Included tests to ensure the unified API matches the behavior of the now-internal _impl functions for various configurations (VK bf16 T=1, VK fp32 T=1, KV fp32 T=1, VK fp32 MTP).
    • Added tests for intermediate state buffer handling in MTP.
    • Included tests for edge cases like batch size 1 and pool size 1.
    • Added tests for error handling of unsupported state_layout values, state_indices with KV layout, and VK fp32 T>1 without a pool.
Activity
  • The pull request introduces a new unified API for GDN decode operations, along with deprecation shims for older functions.
  • Existing backend implementations were refactored and renamed to internal functions.
  • Comprehensive tests were added to validate the new unified API's correctness and error handling.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

Walkthrough

Adds a unified gated_delta_rule_decode dispatcher in flashinfer/gdn_decode that routes VK/KV/MTP backends based on state_layout, dtype, T, and state_indices; renames previous public backend functions to internal _*_impl variants and provides deprecation shim wrappers. Exposes the unified symbol at package top-level and updates benchmarks/tests to call the appropriate backend-specific variant.

Changes

Cohort / File(s) Summary
Unified Decode API
flashinfer/gdn_decode.py
Added public dispatcher gated_delta_rule_decode(...) that validates inputs, dispatches to internal implementations based on state_layout, dtype, T, and state_indices; added _check_state_indices_bounds; renamed original public functions to _..._impl and added deprecation shim wrappers.
Package Exports
flashinfer/__init__.py
Re-exported gated_delta_rule_decode from flashinfer.gdn_decode in the package initializer so it appears in the top-level namespace.
Benchmarks
benchmarks/bench_gdn_decode.py
Switched nontranspose/KV benchmark paths to import and call legacy KV-specific API (gated_delta_rule_decode_kv) instead of the unified name.
Tests — single test update
tests/gdn/test_decode_delta_rule.py
Updated import and call site to use gated_delta_rule_decode_kv for the nontranspose test path.
Tests — new coverage
tests/gdn/test_gdn_decode.py
Added comprehensive tests covering VK/KV/MTP parity, dtype variants (bf16/fp32), pool/index validation, error conditions, and deprecation shim behavior.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant UnifiedDecode as gated_delta_rule_decode
    participant Validate as _check_state_indices_bounds
    participant VK as _gated_delta_rule_decode_pretranspose_impl
    participant KV as _gated_delta_rule_decode_kv_impl
    participant MTP as _gated_delta_rule_mtp_impl

    Caller->>UnifiedDecode: call(state_layout, state, state_indices, tokens, ...)
    alt state_indices provided
        UnifiedDecode->>Validate: _check_state_indices_bounds(state_indices, pool_size)
        Validate-->>UnifiedDecode: ok / raise ValueError
    end

    alt state_layout == "KV" and T == 1
        UnifiedDecode->>KV: dispatch to KV impl
        KV-->>UnifiedDecode: return (out, new_state)
    else state_layout == "VK" and bf16 and T > 1 and pool used
        UnifiedDecode->>MTP: dispatch to MTP impl
        MTP-->>UnifiedDecode: return (out, new_state)
    else state_layout == "VK"
        UnifiedDecode->>VK: dispatch to VK/pretranspose impl
        VK-->>UnifiedDecode: return (out, new_state)
    else
        UnifiedDecode-->>Caller: raise ValueError / NotImplementedError
    end

    UnifiedDecode-->>Caller: final result or error
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

ready

Suggested reviewers

  • yzh119
  • bkryu
  • kaixih
  • cyx-6
  • kahyunnam
  • nvmbreughe

Poem

🐰 I hopped through code with nimble feet,
One decode to rule each VK/KV beat,
Old names bow, new routes take flight,
Pools and indices tucked in tight,
A carrot-ready merge — bright night. 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: adding a unified decode API with deprecation shims, and references the specific RFCs (5.7, 5.8) being implemented.
Description check ✅ Passed The description includes all required template sections: a clear overview of changes, linked issue reference, and pre-commit/test checklists. Implementation details and objectives are well-documented.
Linked Issues check ✅ Passed The PR successfully implements the primary objectives from #2687: unified gated_delta_rule_decode API with state_layout/state_indices parameters, dispatch by dtype/layout/T, deprecation shims for legacy APIs, and comprehensive test coverage for edge cases and error paths.
Out of Scope Changes check ✅ Passed All changes are directly related to implementing the unified GDN decode API: renaming backend implementations, adding the unified router, deprecation shims, exports, and tests. No unrelated modifications detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new unified API gated_delta_rule_decode_unified for Gated Delta Rule (GDN) decode operations, simplifying the API surface by renaming old backend implementations, adding deprecation shims, and including comprehensive tests. However, a security audit identified critical issues related to insecure input validation using assert statements and missing bounds checks for user-supplied indices, which could lead to out-of-bounds memory access on the GPU. These assert statements should be replaced with explicit checks for better robustness.

Comment thread flashinfer/gdn_decode.py Outdated
Comment thread flashinfer/gdn_decode.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2662-2679: The unified public entrypoint
gated_delta_rule_decode_unified is missing the required backend capability
guard; wrap it with the `@backend_requirement` decorator and ensure the code calls
the helper checks (is_compute_capability_supported(cc) and
is_backend_supported()) for the SM90+ requirement so callers fail fast on
unsupported GPUs; apply the same change to the other related unified API(s)
around the other decode function(s) referenced (the functions at the diff around
the 2730–2732 region) so both use the backend guard and support helpers
consistently.
- Around line 2704-2706: state_indices can be negative (padding) but the MTP
path/kernel currently skips padded rows leaving preallocated output rows stale;
update the call site that forwards state_indices into the MTP path to either (A)
validate and reject negative indices by raising an error if any(state_indices <
0) when an output buffer is supplied, or (B) proactively zero the corresponding
output rows before launching the MTP/kernel when output is provided;
specifically, check state_indices for negatives, and if choosing (B) zero
output[mask] (where mask = state_indices < 0) prior to the mtp kernel launch so
padded rows do not retain old values. Ensure this change is applied both where
state_indices is forwarded into the MTP path and in the analogous block handling
lines around the second occurrence mentioned (the other MTP call).
- Around line 2895-2930: The shim gated_delta_rule_decode_pretranspose lost
legacy validation: restore checks so callers must provide either state (state)
or the per-step initial state pair (initial_state and initial_state_indices) but
not both, and if initial_state is provided then initial_state_indices must also
be provided (and conversely initial_state_indices must be None when
initial_state is None); add explicit ValueError(s) with clear messages before
delegating to gated_delta_rule_decode_unified to prevent bad calls (refer to
symbols initial_state, initial_state_indices, state,
gated_delta_rule_decode_pretranspose, and gated_delta_rule_decode_unified).
- Around line 2771-2815: The BF16 branch currently always forwards the state
into _gated_delta_rule_decode_pretranspose_impl (via
initial_state/initial_state_indices or state), which prevents honoring
disable_state_update and prevents filling intermediate_states_buffer; change the
BF16 branch so that if disable_state_update is True you do not pass the current
state (pass state=None and omit initial_state/initial_state_indices) and if
intermediate_states_buffer is requested do not use the unified BF16 path but
fall back to the non-bf16/FP32 decode path that supports population of the
rollback buffer (or otherwise implement buffer population), using the same
checks around state.shape/state_indices but routing to the alternative code path
instead of always calling _gated_delta_rule_decode_pretranspose_impl.

In `@tests/gdn/test_gdn_decode_unified.py`:
- Around line 215-357: The tests miss exercising the bfloat16 dispatcher path
and fail to assert intermediate buffer rollback; update the param setup so at
least one T>1 case uses state_pool and intermediate buffers with
dtype=torch.bfloat16 (so gated_delta_rule_decode_unified hits the state.dtype ==
torch.bfloat16 branch) and add an assertion comparing intermed_unified to
intermed_legacy (e.g., torch.testing.assert_close(intermed_unified,
intermed_legacy, atol=5e-3, rtol=5e-3)) in
test_unified_vk_fp32_mtp_with_intermediate_buffer_matches_mtp; ensure the same
dtype change is applied consistently for pool_unified/pool_legacy and for any
places that construct intermed_buf so the BF16 dispatcher path in
gated_delta_rule_decode_unified and rollback caching behavior are both
exercised.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ded46dee-cac3-4404-85bf-c183c9d98073

📥 Commits

Reviewing files that changed from the base of the PR and between 124a2d3 and 772f477.

📒 Files selected for processing (3)
  • flashinfer/__init__.py
  • flashinfer/gdn_decode.py
  • tests/gdn/test_gdn_decode_unified.py

Comment thread flashinfer/gdn_decode.py
Comment thread flashinfer/gdn_decode.py Outdated
Comment thread flashinfer/gdn_decode.py
Comment thread flashinfer/gdn_decode.py
Comment thread tests/gdn/test_gdn_decode.py
@Dayuxiaoshui
Copy link
Copy Markdown
Author

@kaixih Quick note on naming: the RFC suggests the unified entry point be called gated_delta_rule_decode, but that name is already used by the existing KV-layout API. To avoid breaking callers, this PR introduces the unified API as gated_delta_rule_decode_unified and keeps the old names as deprecation shims. If you prefer to switch to gated_delta_rule_decode for the unified API (and e.g. rename the current KV-layout one or deprecate it under a different name), I can follow up with a small rename patch. Please advise.

@kaixih
Copy link
Copy Markdown
Collaborator

kaixih commented Mar 7, 2026

On naming: we'd recommend the following approach rather than keeping _unified in the name (which is an implementation detail callers shouldn't see):

  1. Rename the existing gated_delta_rule_decode (KV layout) → gated_delta_rule_decode_kv: this is descriptive, clearly signals it's the legacy KV-layout-specific path, and the function is not currently exported from flashinfer/init.py so the blast radius is small.
  2. Use gated_delta_rule_decode for the unified API: the clean name with no suffix, as the RFC originally proposed.
  3. Keep the old gated_delta_rule_decode_pretranspose and gated_delta_rule_mtp as deprecation shims (as you've already done), and add gated_delta_rule_decode_kv as a shim for the renamed KV function.

@kahyunnam @yzh119 how do you think?

@kaixih kaixih mentioned this pull request Mar 7, 2026
43 tasks
@Dayuxiaoshui
Copy link
Copy Markdown
Author

@kaixih I agree with this viewpoint, and I will make corrections and resubmit.

@Dayuxiaoshui
Copy link
Copy Markdown
Author

@kaixih Thanks, done. We renamed the KV path to gated_delta_rule_decode_kv, use gated_delta_rule_decode for the unified API, and kept the three shims (pretranspose, kv, mtp) as you suggested.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (5)
flashinfer/gdn_decode.py (4)

2674-2691: ⚠️ Potential issue | 🟠 Major

Add the standard backend requirement guard to the public decode API.

This is a public SM90+ entrypoint, but unsupported devices still fall through to JIT compilation and fail late instead of getting the usual capability check up front. The same guard should stay consistent across the public shims below. As per coding guidelines, "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2674 - 2691, The public decoder
function gated_delta_rule_decode lacks the standard backend capability guard;
add the `@backend_requirement` decorator above the gated_delta_rule_decode
definition and ensure it calls the module's is_compute_capability_supported(cc)
and is_backend_supported() helpers so unsupported devices fail early; also add
the same `@backend_requirement` to the other public SM90+ shim functions in this
file to keep guards consistent across the public decode APIs.

2926-2933: ⚠️ Potential issue | 🟠 Major

Reject ambiguous state + initial_state calls in the shim.

The restored validation still allows both to be passed together; the wrapper just takes the pool path and silently ignores state. Legacy callers used to get a clear error for that ambiguous combination.

Suggested guard
     use_pool = initial_state is not None
     if use_pool != (initial_state_indices is not None):
         raise ValueError(
             "initial_state and initial_state_indices must be provided together"
         )
+    if state is not None and initial_state is not None:
+        raise ValueError("state and initial_state are mutually exclusive")
     if state is None and initial_state is None:
         raise ValueError("Either state or initial_state must be provided")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2926 - 2933, The current shim silently
prefers the pooling path when both state and initial_state are passed; change
the validation in the block around use_pool to explicitly reject the ambiguous
combination by raising a ValueError if both state is not None and initial_state
is not None (or equivalently if use_pool and state is not None). Update the
checks involving use_pool, initial_state_indices, initial_state and state (the
variables/functions referenced: use_pool, initial_state_indices, initial_state,
state) so that: 1) passing initial_state requires initial_state_indices, 2)
passing both state and initial_state raises ValueError, and 3) the existing
"Either state or initial_state must be provided" behavior remains intact.

2785-2838: ⚠️ Potential issue | 🟠 Major

BF16 T>1 direct-state calls still ignore rollback knobs.

When state.dtype == torch.bfloat16, T>1, and state_indices is None, this branch still forwards to _gated_delta_rule_decode_pretranspose_impl() without checking disable_state_update or intermediate_states_buffer. Both arguments are silently ignored, so callers can request read-only execution or rollback caching and still get in-place mutation with no cached states.

Suggested guard
     if state.dtype == torch.bfloat16:
+        if T > 1 and (
+            disable_state_update or intermediate_states_buffer is not None
+        ):
+            raise NotImplementedError(
+                "VK bf16 T>1 does not support disable_state_update or "
+                "intermediate_states_buffer yet"
+            )
         if T not in (1, 2, 3, 4) or K != 128 or V != 128:
             raise ValueError(
                 f"VK bf16 path requires T in {{1,2,3,4}} and K=V=128, got T={T}, K={K}, V={V}"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2785 - 2838, The BF16 non-pool branch
currently ignores rollback/read-only knobs: when state.dtype == torch.bfloat16
and use_pool is False (the branch that validates state.shape and calls
_gated_delta_rule_decode_pretranspose_impl with state=state), add the same guard
as the pool path to check if disable_state_update is True or
intermediate_states_buffer is not None and raise NotImplementedError (with the
same or similar message instructing to use fp32 state for MTP); ensure this
check occurs before validating state.shape and before calling
_gated_delta_rule_decode_pretranspose_impl so callers cannot request
read-only/rollback behavior that will be silently ignored.

2662-2671: ⚠️ Potential issue | 🟠 Major

Negative state_indices are rejected instead of treated as padding.

RFC 5.7/5.8 calls out negative-index padding semantics, but this helper raises on every negative entry. The unified API therefore still can't represent padded rows, and the new tests now lock in the opposite contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2662 - 2671, The helper
_check_state_indices_bounds currently treats negative state_indices as errors
but per RFC 5.7/5.8 negatives are padding and should be ignored; change the
validation to only validate non-negative indices: if state_indices.numel() == 0
return; build a non_negative mask = state_indices >= 0 and if no non-negative
entries return; compute bad = non_negative & (state_indices >= pool_size) and if
bad.any() raise ValueError with the first out-of-range value from
state_indices[bad]; keep references to the same symbols
(_check_state_indices_bounds, state_indices, pool_size, bad) so the change is
local and preserves behavior for real indices while allowing negative padding.
tests/gdn/test_gdn_decode.py (1)

216-280: ⚠️ Potential issue | 🟠 Major

Add one BF16 T>1 pool parity case.

All MTP parity tests here still build torch.float32 state pools, so gated_delta_rule_decode() never exercises its state.dtype == torch.bfloat16 VK branch for T>1. That leaves a distinct dispatcher/backend path unverified even though this PR adds BF16 routing there.

Also applies to: 283-360

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_decode.py` around lines 216 - 280, The test file builds
state_pool as torch.float32 so gated_delta_rule_decode never exercises its
bfloat16 VK branch for T>1; add an additional parameterized parity case where
state_pool (and pool_legacy/pool_unified) are created with dtype=torch.bfloat16
for T>1 so gated_delta_rule_decode and _gated_delta_rule_mtp_impl are both run
with bfloat16 state pools; update the test (functions
test_gated_delta_rule_decode_vk_fp32_mtp_matches_mtp and the similar block at
lines 283-360) to include a BF16 variant (or a separate test) that constructs
state_pool with dtype=torch.bfloat16 and asserts out_unified == out_legacy and
pool_unified == pool_legacy with the same tolerances.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2965-2999: The gated_delta_rule_decode_kv function currently emits
a DeprecationWarning and a docstring marking it deprecated; remove the warning
and change the docstring so gated_delta_rule_decode_kv remains a stable,
supported KV-specific entrypoint while still delegating to
gated_delta_rule_decode (keep the call that passes state_layout="KV" and all
parameters intact). Locate the gated_delta_rule_decode_kv function and delete
the warnings.warn block and the “Deprecated” wording in the docstring so callers
migrating to this explicit KV API do not get deprecated warnings.

---

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2674-2691: The public decoder function gated_delta_rule_decode
lacks the standard backend capability guard; add the `@backend_requirement`
decorator above the gated_delta_rule_decode definition and ensure it calls the
module's is_compute_capability_supported(cc) and is_backend_supported() helpers
so unsupported devices fail early; also add the same `@backend_requirement` to the
other public SM90+ shim functions in this file to keep guards consistent across
the public decode APIs.
- Around line 2926-2933: The current shim silently prefers the pooling path when
both state and initial_state are passed; change the validation in the block
around use_pool to explicitly reject the ambiguous combination by raising a
ValueError if both state is not None and initial_state is not None (or
equivalently if use_pool and state is not None). Update the checks involving
use_pool, initial_state_indices, initial_state and state (the
variables/functions referenced: use_pool, initial_state_indices, initial_state,
state) so that: 1) passing initial_state requires initial_state_indices, 2)
passing both state and initial_state raises ValueError, and 3) the existing
"Either state or initial_state must be provided" behavior remains intact.
- Around line 2785-2838: The BF16 non-pool branch currently ignores
rollback/read-only knobs: when state.dtype == torch.bfloat16 and use_pool is
False (the branch that validates state.shape and calls
_gated_delta_rule_decode_pretranspose_impl with state=state), add the same guard
as the pool path to check if disable_state_update is True or
intermediate_states_buffer is not None and raise NotImplementedError (with the
same or similar message instructing to use fp32 state for MTP); ensure this
check occurs before validating state.shape and before calling
_gated_delta_rule_decode_pretranspose_impl so callers cannot request
read-only/rollback behavior that will be silently ignored.
- Around line 2662-2671: The helper _check_state_indices_bounds currently treats
negative state_indices as errors but per RFC 5.7/5.8 negatives are padding and
should be ignored; change the validation to only validate non-negative indices:
if state_indices.numel() == 0 return; build a non_negative mask = state_indices
>= 0 and if no non-negative entries return; compute bad = non_negative &
(state_indices >= pool_size) and if bad.any() raise ValueError with the first
out-of-range value from state_indices[bad]; keep references to the same symbols
(_check_state_indices_bounds, state_indices, pool_size, bad) so the change is
local and preserves behavior for real indices while allowing negative padding.

In `@tests/gdn/test_gdn_decode.py`:
- Around line 216-280: The test file builds state_pool as torch.float32 so
gated_delta_rule_decode never exercises its bfloat16 VK branch for T>1; add an
additional parameterized parity case where state_pool (and
pool_legacy/pool_unified) are created with dtype=torch.bfloat16 for T>1 so
gated_delta_rule_decode and _gated_delta_rule_mtp_impl are both run with
bfloat16 state pools; update the test (functions
test_gated_delta_rule_decode_vk_fp32_mtp_matches_mtp and the similar block at
lines 283-360) to include a BF16 variant (or a separate test) that constructs
state_pool with dtype=torch.bfloat16 and asserts out_unified == out_legacy and
pool_unified == pool_legacy with the same tolerances.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7258ed1b-b041-4012-9d81-69215acaecd1

📥 Commits

Reviewing files that changed from the base of the PR and between 772f477 and 1a37574.

📒 Files selected for processing (5)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/__init__.py
  • flashinfer/gdn_decode.py
  • tests/gdn/test_decode_delta_rule.py
  • tests/gdn/test_gdn_decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/init.py

Comment thread flashinfer/gdn_decode.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (2)
flashinfer/gdn_decode.py (2)

649-666: 🛠️ Refactor suggestion | 🟠 Major

Add @backend_requirement decorator for SM90+ requirement.

The unified API is documented as requiring SM90+ (line 719), but lacks the @backend_requirement decorator. Without it, callers on unsupported GPUs will fail late during JIT compilation rather than getting a clear capability check upfront.

As per coding guidelines: "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 649 - 666, The gated_delta_rule_decode
API lacks the required SM90+ capability check decorator; add
`@backend_requirement` with the SM90+ requirement above the
gated_delta_rule_decode definition so callers get an upfront capability check.
Ensure the decorator uses the existing helper functions
is_compute_capability_supported(cc) and is_backend_supported() to express the
SM90+ requirement (matching the unified API docs) so the GPU capability/ backend
support is validated before JIT compilation.

940-974: ⚠️ Potential issue | 🟠 Major

Remove deprecation warning from gated_delta_rule_decode_kv.

Per the PR discussion, reviewer kaixih recommended keeping gated_delta_rule_decode_kv as the stable explicit KV-layout entrypoint. This shim was introduced as the new name for the legacy KV path, so deprecating it immediately contradicts the migration guidance. Keep it as a thin delegation wrapper without the warning.

🔧 Suggested fix
 `@flashinfer_api`
 def gated_delta_rule_decode_kv(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     state: torch.Tensor,
     A_log: torch.Tensor,
     a: torch.Tensor,
     dt_bias: torch.Tensor,
     b: torch.Tensor,
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Deprecated: use gated_delta_rule_decode(..., state_layout=\"KV\") instead."""
-    warnings.warn(
-        "gated_delta_rule_decode_kv is deprecated and will be removed in a future "
-        "version. Use gated_delta_rule_decode(..., state_layout='KV') instead.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
+    """KV-layout decode API. Delegates to gated_delta_rule_decode with state_layout='KV'."""
     return gated_delta_rule_decode(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 940 - 974, The
gated_delta_rule_decode_kv wrapper currently emits a DeprecationWarning and its
docstring marks it deprecated; remove the warnings.warn call and the
DeprecationWarning/stacklevel arguments and update the docstring to reflect that
this is the stable KV-layout entrypoint, leaving the function body to simply
delegate to gated_delta_rule_decode(..., state_layout="KV") with the same
parameters (q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm).
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)

30-31: Unused functools import.

functools is imported but not used anywhere in the file. Remove it if not needed.

🧹 Suggested fix
-import functools
 import warnings
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 30 - 31, Remove the unused import of
functools from the top of gdn_decode.py: locate the import statement "import
functools" and delete it (ensure no references to functools exist elsewhere in
functions or classes in this file, e.g., any usage in functions or decorators);
after removal, run the linter/tests to confirm no unresolved references remain.

722-723: Unused variable H.

H is unpacked from q.shape but never used in the function body. Consider using _ to indicate it's intentionally ignored.

🧹 Suggested fix
-    B, T, H, K = q.shape
+    B, T, _, K = q.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 722 - 723, The tuple unpacking of
q.shape assigns H but H is never used; change the unpack to ignore that
dimension (e.g., B, T, _, K = q.shape) so the unused variable is explicit, and
ensure no other code relies on H (update any references if present); leave the v
unpack (_, _, HV, V = v.shape) as-is.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 649-666: The gated_delta_rule_decode API lacks the required SM90+
capability check decorator; add `@backend_requirement` with the SM90+ requirement
above the gated_delta_rule_decode definition so callers get an upfront
capability check. Ensure the decorator uses the existing helper functions
is_compute_capability_supported(cc) and is_backend_supported() to express the
SM90+ requirement (matching the unified API docs) so the GPU capability/ backend
support is validated before JIT compilation.
- Around line 940-974: The gated_delta_rule_decode_kv wrapper currently emits a
DeprecationWarning and its docstring marks it deprecated; remove the
warnings.warn call and the DeprecationWarning/stacklevel arguments and update
the docstring to reflect that this is the stable KV-layout entrypoint, leaving
the function body to simply delegate to gated_delta_rule_decode(...,
state_layout="KV") with the same parameters (q, k, v, state, A_log, a, dt_bias,
b, scale, output, use_qk_l2norm).

---

Nitpick comments:
In `@flashinfer/gdn_decode.py`:
- Around line 30-31: Remove the unused import of functools from the top of
gdn_decode.py: locate the import statement "import functools" and delete it
(ensure no references to functools exist elsewhere in functions or classes in
this file, e.g., any usage in functions or decorators); after removal, run the
linter/tests to confirm no unresolved references remain.
- Around line 722-723: The tuple unpacking of q.shape assigns H but H is never
used; change the unpack to ignore that dimension (e.g., B, T, _, K = q.shape) so
the unused variable is explicit, and ensure no other code relies on H (update
any references if present); leave the v unpack (_, _, HV, V = v.shape) as-is.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2de30a05-590d-4b0c-b214-d6829960d2bf

📥 Commits

Reviewing files that changed from the base of the PR and between 1a37574 and 1a75db1.

📒 Files selected for processing (3)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • tests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • benchmarks/bench_gdn_decode.py
  • tests/gdn/test_decode_delta_rule.py

@Dayuxiaoshui
Copy link
Copy Markdown
Author

cc @kaixih

Comment thread flashinfer/gdn_decode.py
Comment thread flashinfer/gdn_decode.py
Comment thread flashinfer/gdn_decode.py
@Dayuxiaoshui
Copy link
Copy Markdown
Author

cc @kaixih

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (5)
flashinfer/gdn_decode.py (5)

711-712: ⚠️ Potential issue | 🟠 Major

Add @backend_requirement guard to the unified public API.

This entrypoint is SM-capability constrained but currently only has @flashinfer_api, so unsupported devices fail late/inconsistently.

As per coding guidelines, "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 711 - 712, The public API
gated_delta_rule_decode is missing the backend capability guard: add the
`@backend_requirement` decorator to the gated_delta_rule_decode definition and
implement/ensure the referenced helpers are present—provide or reference
is_compute_capability_supported(cc) and is_backend_supported() so the decorator
can validate SM/compute capability before entry; update imports to bring in
backend_requirement and wire it on the gated_delta_rule_decode symbol so
unsupported devices are rejected early and consistently.

699-708: ⚠️ Potential issue | 🟠 Major

Allow negative state_indices for padding semantics in the unified API.

This helper rejects < 0 indices, which breaks the documented/expected padding semantics for pooled decode paths.

💡 Suggested fix
 def _check_state_indices_bounds(state_indices: torch.Tensor, pool_size: int) -> None:
-    """Validate that all state_indices are in [0, pool_size). Raises ValueError if not."""
+    """Validate that non-padding state_indices are < pool_size. Negative values are padding."""
     if state_indices.numel() == 0:
         return
-    bad = (state_indices < 0) | (state_indices >= pool_size)
+    bad = state_indices >= pool_size
     if bad.any().item():
         first_bad = state_indices[bad].flatten()[0].item()
         raise ValueError(
-            f"state_indices must be in [0, pool_size={pool_size}); got out-of-range value {first_bad}"
+            f"state_indices must be < pool_size={pool_size}; got out-of-range value {first_bad}"
         )

Also applies to: 754-757

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 699 - 708, The helper
_check_state_indices_bounds currently rejects negative indices but we need to
allow negative values for padding semantics; modify the validation to only flag
indices >= pool_size (remove the state_indices < 0 check), update the raised
ValueError message accordingly to reference only the upper bound, and apply the
same change to the duplicate occurrence (the other block around lines 754-757)
so pooled decode paths accept negative padding indices.

1016-1022: ⚠️ Potential issue | 🟠 Major

Do not deprecate gated_delta_rule_decode_kv immediately after introducing it.

This shim is the explicit KV migration path; warning on every call makes the migration target itself deprecated.

💡 Suggested fix
-def gated_delta_rule_decode_kv(
+def gated_delta_rule_decode_kv(
@@
-    """Deprecated: use gated_delta_rule_decode(..., state_layout=\"KV\") instead."""
-    warnings.warn(
-        "gated_delta_rule_decode_kv is deprecated and will be removed in a future "
-        "version. Use gated_delta_rule_decode(..., state_layout='KV') instead.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
+    """KV-layout decode API."""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1016 - 1022, The function
gated_delta_rule_decode_kv is emitting a DeprecationWarning on every call which
incorrectly marks the KV migration shim as deprecated; remove the
warnings.warn(...) deprecation call (the three-argument warnings.warn block that
mentions "gated_delta_rule_decode_kv is deprecated") from
gated_delta_rule_decode_kv so the shim remains the recommended migration
entrypoint (or replace it with a single informational log if you need
messaging), leaving the function implementation intact.

883-886: ⚠️ Potential issue | 🟠 Major

Route VK fp32 T==1 pool mode instead of raising NotImplementedError.

The unified router blocks a path that _gated_delta_rule_decode_pretranspose_impl already supports (pool + indices on pretranspose decode).

💡 Suggested fix
     if T == 1:
-        if use_pool:
-            raise NotImplementedError(
-                "VK fp32 T=1 with state_indices (pool) is not implemented yet"
-            )
+        if use_pool:
+            pool_size = state.shape[0]
+            if state.shape != (pool_size, HV, V, K):
+                raise ValueError(
+                    f"Expected state [pool_size, HV, V, K] for VK, got {state.shape}"
+                )
+            if state_indices.shape != (B,):
+                raise ValueError(f"state_indices must be [B={B}], got {state_indices.shape}")
+            _check_state_indices_bounds(state_indices, pool_size)
+            return _gated_delta_rule_decode_pretranspose_impl(
+                q=q, k=k, v=v, state=None, A_log=A_log, a=a, dt_bias=dt_bias, b=b,
+                scale=scale, output=output, use_qk_l2norm=use_qk_l2norm,
+                initial_state=state, initial_state_indices=state_indices,
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 883 - 886, Remove the
NotImplementedError branch for VK fp32 when T==1 and route pool mode through the
existing implementation: replace the raise inside the if use_pool block with a
call/delegation to _gated_delta_rule_decode_pretranspose_impl (passing the same
inputs/state_indices/use_pool and any other local params needed) so the
pretranspose path that already supports pool + indices is used instead of
blocking; ensure the call is conditioned on the same VK fp32 and T==1 checks and
preserves return values and side effects.

968-970: ⚠️ Potential issue | 🟡 Minor

Reject ambiguous shim input when both state and initial_state are provided.

Current validation allows both and silently prefers initial_state, which can hide caller bugs.

💡 Suggested fix
     if state is None and initial_state is None:
         raise ValueError("Either state or initial_state must be provided")
+    if state is not None and initial_state is not None:
+        raise ValueError("Provide either state or initial_state, not both")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 968 - 970, The current validation in
gdn_decode.py allows both state and initial_state to be provided and silently
prefers initial_state; update the check in the function that handles these
parameters so that if both state and initial_state are not None it raises a
ValueError (e.g., "Provide either state or initial_state, not both"), otherwise
continue with the existing logic that uses the single provided value; reference
the parameters named state and initial_state to locate and change the
conditional that currently only checks for both being None and then checks
initial_state.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Line 784: The unpacking "B, T, H, K = q.shape" exposes an unused variable H;
update the unpack to prefix that unused component with an underscore (e.g., "B,
T, _H, K" or "B, T, _, K") so linters won’t flag it and intent is clearer;
modify the unpack expression where q.shape is destructured and ensure no other
references to H remain in the surrounding function.

---

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 711-712: The public API gated_delta_rule_decode is missing the
backend capability guard: add the `@backend_requirement` decorator to the
gated_delta_rule_decode definition and implement/ensure the referenced helpers
are present—provide or reference is_compute_capability_supported(cc) and
is_backend_supported() so the decorator can validate SM/compute capability
before entry; update imports to bring in backend_requirement and wire it on the
gated_delta_rule_decode symbol so unsupported devices are rejected early and
consistently.
- Around line 699-708: The helper _check_state_indices_bounds currently rejects
negative indices but we need to allow negative values for padding semantics;
modify the validation to only flag indices >= pool_size (remove the
state_indices < 0 check), update the raised ValueError message accordingly to
reference only the upper bound, and apply the same change to the duplicate
occurrence (the other block around lines 754-757) so pooled decode paths accept
negative padding indices.
- Around line 1016-1022: The function gated_delta_rule_decode_kv is emitting a
DeprecationWarning on every call which incorrectly marks the KV migration shim
as deprecated; remove the warnings.warn(...) deprecation call (the
three-argument warnings.warn block that mentions "gated_delta_rule_decode_kv is
deprecated") from gated_delta_rule_decode_kv so the shim remains the recommended
migration entrypoint (or replace it with a single informational log if you need
messaging), leaving the function implementation intact.
- Around line 883-886: Remove the NotImplementedError branch for VK fp32 when
T==1 and route pool mode through the existing implementation: replace the raise
inside the if use_pool block with a call/delegation to
_gated_delta_rule_decode_pretranspose_impl (passing the same
inputs/state_indices/use_pool and any other local params needed) so the
pretranspose path that already supports pool + indices is used instead of
blocking; ensure the call is conditioned on the same VK fp32 and T==1 checks and
preserves return values and side effects.
- Around line 968-970: The current validation in gdn_decode.py allows both state
and initial_state to be provided and silently prefers initial_state; update the
check in the function that handles these parameters so that if both state and
initial_state are not None it raises a ValueError (e.g., "Provide either state
or initial_state, not both"), otherwise continue with the existing logic that
uses the single provided value; reference the parameters named state and
initial_state to locate and change the conditional that currently only checks
for both being None and then checks initial_state.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4e26e059-7fcb-426e-8310-802ead52e1de

📥 Commits

Reviewing files that changed from the base of the PR and between 1a75db1 and 0649d07.

📒 Files selected for processing (3)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/__init__.py
  • flashinfer/gdn_decode.py
✅ Files skipped from review due to trivial changes (1)
  • benchmarks/bench_gdn_decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/init.py

Comment thread flashinfer/gdn_decode.py
- Requires SM90+ (Hopper, Blackwell, etc.). All backends are JIT-compiled and tested on SM90/100/110/120.
- State is updated in-place; with pool (state_indices), updates write into the state tensor.
"""
B, T, H, K = q.shape
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Prefix unused unpacked variable with _ to satisfy lint and improve clarity.

H is unpacked but never used in this function body.

🧰 Tools
🪛 Ruff (0.15.9)

[warning] 784-784: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` at line 784, The unpacking "B, T, H, K = q.shape"
exposes an unused variable H; update the unpack to prefix that unused component
with an underscore (e.g., "B, T, _H, K" or "B, T, _, K") so linters won’t flag
it and intent is clearer; modify the unpack expression where q.shape is
destructured and ensure no other references to H remain in the surrounding
function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC] Unified GDN Decode/Prefill API

2 participants